package com.ruoyi.common.utils.sql;

import cn.hutool.core.collection.CollUtil;
import cn.hutool.core.util.StrUtil;
import com.ruoyi.common.exception.UtilException;
import com.ruoyi.common.req.QueryReq;
import com.ruoyi.common.utils.StringUtils;

import java.util.List;

/**
 * sql操作工具类
 *
 * @author ruoyi
 */
public class SqlUtil {
    /**
     * 定义常用的 sql关键字
     */
    public static String SQL_REGEX = "and |extractvalue|updatexml|exec |insert |select |delete |update |drop |count |chr |mid |master |truncate |char |declare |or |+|user()";

    /**
     * 仅支持字母、数字、下划线、空格、逗号、小数点（支持多个字段排序）
     */
    public static String SQL_PATTERN = "[a-zA-Z0-9_\\ \\,\\.]+";

    /**
     * 限制orderBy最大长度
     */
    private static final int ORDER_BY_MAX_LENGTH = 500;

    /**
     * 默认分页大小
     */
    private static final int DEFAULT_SIZE = 10;

    /**
     * 获取分页SQL
     *
     * @param req 搜索基类
     * @return
     */
    public static String getLimit(QueryReq req) {
        return getLimit(req.getCurrent(), req.getSize());
    }

    /**
     * 根据当前页和页面大小获取分页SQL
     *
     * @param current 当前页
     * @param size    页面大小
     * @return
     */
    public static String getLimit(Long current, Long size) {
        current = null == current ? 1L : current;
        size = null == size || size <= 0 ? DEFAULT_SIZE : size;
        Long start = (current - 1) * size;
        start = start < 0 ? 0 : start;
        StringBuilder sb = new StringBuilder("LIMIT ");
        return sb.append(start).append(",").append(size).toString();
    }

    /**
     * 根据页面大小获取分页SQL
     *
     * @param size 页面大小
     * @return
     */
    public static String getLimit(Long size) {
        return getLimit(null, size);
    }

    /**
     * 查询一个数据
     *
     * @return
     */
    public static String getALimit() {
        return getLimit(1L);
    }

    /**
     * 根据表字段属性按指定值排序SQL
     *
     * @param column   排序属性列
     * @param sortList 指定顺序
     * @return
     */
    public static String getFixedSortSql(String column, List<?> sortList) {
        if (StrUtil.isBlank(column) || CollUtil.isEmpty(sortList)) return "";
        StringBuilder sb = new StringBuilder(StrUtil.SPACE + "ORDER BY FIELD(");
        return sb.append(column).append(",").append(CollUtil.join(sortList, StrUtil.COMMA)).append(")").append(StrUtil.SPACE).toString();
    }

    /**
     * 检查字符，防止注入绕过
     */
    public static String escapeOrderBySql(String value) {
        if (StringUtils.isNotEmpty(value) && !isValidOrderBySql(value)) {
            throw new UtilException("参数不符合规范，不能进行查询");
        }
        if (StringUtils.length(value) > ORDER_BY_MAX_LENGTH) {
            throw new UtilException("参数已超过最大限制，不能进行查询");
        }
        return value;
    }

    /**
     * 验证 order by 语法是否符合规范
     */
    public static boolean isValidOrderBySql(String value) {
        return value.matches(SQL_PATTERN);
    }

    /**
     * SQL关键字检查
     */
    public static void filterKeyword(String value) {
        if (StringUtils.isEmpty(value)) {
            return;
        }
        String[] sqlKeywords = StringUtils.split(SQL_REGEX, "\\|");
        for (String sqlKeyword : sqlKeywords) {
            if (StringUtils.indexOfIgnoreCase(value, sqlKeyword) > -1) {
                throw new UtilException("参数存在SQL注入风险");
            }
        }
    }
}
